Add tree attention backend for v1 (part 1) #20401
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
Summary of Changes
Hello @TheEpicDolphin, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request integrates the initial phase of a Tree Attention backend into v1 of the attention system, specifically to support EAGLE speculative decoding. The changes enable the efficient validation of draft tokens by implementing a tree-based attention mechanism that correctly applies necessary attention biases. This work involves significant additions to the attention backend infrastructure, updates to model architecture to utilize the new backend, and includes a correctness test to ensure functionality.
Highlights
- New Tree Attention Backend: Introduced `TreeAttentionBackend` and `TreeAttentionImpl` to add support for tree attention, a key component for EAGLE speculative decoding in v1 of the attention system.
- Attention Bias Implementation: `TreeAttentionImpl` leverages `xformers.ops.tree_attention` and correctly applies both prefix and speculative (suffix) attention biases, essential for managing attention between draft tokens and their ancestors or prompt tokens.
- Dynamic Backend Selection and Draft Model Support: The attention backend selection logic has been updated to include `TREE_ATTN` and now incorporates an `is_draft` flag, allowing the system to differentiate and select appropriate attention backends for draft models within the speculative decoding framework.
- Optimized Batch Processing: A new `TreeAttentionMetadataBuilder` was added to reorder batches, prioritizing decode requests, and to efficiently construct attention metadata for both prefill (handled by FlashAttention) and speculative decode phases.
- Correctness Validation: A new test, `test_tree_attn_correctness`, was implemented to verify the numerical correctness of `TreeAttentionBackend` by comparing its output against `FlashAttentionBackend` across various configurations.
Code Review
This pull request introduces a new `TreeAttentionBackend` for speculative decoding, which is a significant feature addition. The implementation is well-structured, reusing `FlashAttentionImpl` for prefill requests and using `xformers` for the tree attention part. The new test file provides good coverage for correctness verification.
I've identified a critical issue with duplicated fields in a dataclass and a few medium-severity issues related to code correctness, performance, and maintainability. Addressing these will improve the quality and robustness of the new backend. Overall, this is a great first step towards enabling tree attention.
```python
    backends=[FlashAttentionBackend, TreeAttentionBackend],
)
assert torch.allclose(
    flash_attn_output, tree_attn_output, atol=1e-2
```
The absolute tolerance `atol=1e-2` is a bit high for `bfloat16` tensors, which have a machine epsilon of about `7.81e-3`. While this might be necessary due to error accumulation in the attention computation, it would be good to either tighten this tolerance if possible, or add a comment explaining why this level of tolerance is required. A tighter tolerance would give more confidence in the correctness of the implementation.
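As an illustrative option (not part of this PR), `torch.testing.assert_close` reports the largest absolute and relative mismatch on failure, which makes it easier to justify, or tighten, an explicit tolerance; the tensors below are placeholders:

```python
import torch

# Placeholder tensors standing in for flash_attn_output / tree_attn_output.
flash_attn_output = torch.randn(4, 8, 64, dtype=torch.bfloat16)
tree_attn_output = flash_attn_output + 1e-3 * torch.randn_like(flash_attn_output)

# On failure, assert_close prints the greatest absolute/relative difference
# and the number of mismatched elements, which documents how much headroom
# an atol like 1e-2 actually has for bfloat16 outputs.
torch.testing.assert_close(tree_attn_output, flash_attn_output, rtol=1.6e-2, atol=1e-2)
```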
```python
# Save for next `build` call
# TODO(lucas): this is a bit of a hack, we should probably have a
# better way of doing this
self._num_decodes = num_decodes
self._num_prefills = num_prefills
self._num_decode_tokens = num_decode_tokens
self._num_prefill_tokens = num_prefill_tokens
```
Storing intermediate state like `_num_decodes`, `_num_prefills`, etc. on `self` between calls to `reorder_batch` and `build` can make the code harder to reason about and potentially fragile. The `TODO` comment acknowledges this.
A cleaner approach might be for `reorder_batch` to return this information, and for the caller (in `GPUModelRunner`) to pass it to `build`. This would make the data flow more explicit and improve maintainability.
For example:
```python
# In TreeAttentionMetadataBuilder
def reorder_batch(...) -> tuple[bool, dict[str, Any]]:
    ...
    reorder_info = {
        "num_decodes": num_decodes,
        "num_prefills": num_prefills,
        ...
    }
    return modified_batch, reorder_info

# In GPUModelRunner
modified_batch, reorder_info = self.attn_metadata_builder.reorder_batch(...)
...
self.attn_metadata = self.attn_metadata_builder.build(..., reorder_info=reorder_info)

# In TreeAttentionMetadataBuilder
def build(..., reorder_info: dict[str, Any]) -> TreeAttentionMetadata:
    num_decodes = reorder_info["num_decodes"]
    ...
```
Since this would require changes outside of this file, this can be addressed in a follow-up PR.
```python
ancestor_idx = []
for c in range(len(cur_tree_choice) - 1):
    ancestor_idx.append(
        sorted_tree_choices.index(cur_tree_choice[: c + 1]) + 1
```
The use of `sorted_tree_choices.index()` inside a loop can lead to quadratic complexity with respect to the number of nodes in the tree. While this is likely not an issue for the small trees currently used in speculative decoding, it could become a performance bottleneck if larger or more complex trees are supported in the future.
Consider pre-computing a mapping from path to index to achieve O(1) lookups. For example:
```python
path_to_idx = {path: i for i, path in enumerate(sorted_tree_choices)}
# ... inside the loop ...
ancestor_idx.append(path_to_idx[cur_tree_choice[: c + 1]] + 1)
```
This would improve the maintainability and future-proof the code against performance issues with larger trees.
This pull request has merge conflicts that must be resolved before it can be merged.
Thanks for integrating tree attention! Left a few comments. Regarding the performance, maybe look at the profiles to see what takes the most time - it could be the tree attention itself, but it could also be metadata processing (which we can then take out of the decoding loop, at least partially).
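For reference, a minimal PyTorch profiling sketch along those lines; `run_one_decode_step` and the `record_function` labels are placeholders, not vLLM APIs:

```python
import torch
from torch.profiler import ProfilerActivity, profile, record_function

def run_one_decode_step():
    # Placeholder for a single speculative-decode step of the model under test.
    pass

# Add ProfilerActivity.CUDA (and sort by "cuda_time_total") when running on GPU.
with profile(activities=[ProfilerActivity.CPU]) as prof:
    with record_function("build_attn_metadata"):
        pass  # placeholder: metadata construction, e.g. the builder's build()
    with record_function("tree_attention_forward"):
        run_one_decode_step()

# Compare the time spent in metadata processing vs. the attention itself.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
```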
```python
    device=device,
    dtype=torch.int32,
).view(-1, num_allocated_blocks_per_batch)
block_table[:, :num_allocated_blocks_per_batch] = block_ids
```
This simulates a situation where pages are ordered contiguously in physical memory. Would the test also work in a more complex scenario? For example, you could swap two pages, or even shuffle them all.
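For instance, a small sketch of what such a non-contiguous layout could look like (placeholder sizes, not the test's actual variables):

```python
import torch

num_batches, num_allocated_blocks_per_batch = 4, 8  # placeholder sizes

# Contiguous layout used in the current test: batch i owns blocks
# [i * 8, i * 8 + 1, ...].
block_ids = torch.arange(
    num_batches * num_allocated_blocks_per_batch, dtype=torch.int32
).view(-1, num_allocated_blocks_per_batch)

# Swap two pages within the first request ...
swapped = block_ids.clone()
swapped[0, [0, 1]] = block_ids[0, [1, 0]]

# ... or shuffle every page globally; each physical block is still assigned
# to exactly one (request, logical block) slot, just no longer contiguously.
perm = torch.randperm(num_batches * num_allocated_blocks_per_batch)
shuffled = block_ids.view(-1)[perm].view(-1, num_allocated_blocks_per_batch)
```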
```diff
@@ -1442,6 +1442,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
     "ROCM_AITER_MLA",
     "TORCH_SDPA_VLLM_V1",
     "FLEX_ATTENTION",
+    "TREE_ATTN",
```
Nit: is the comment above "No XFormers so far" still true if you are importing tree attention from xFormers?
```diff
@@ -134,7 +134,7 @@ def _get_sliding_window_configs(
     sliding_window_configs: set[Optional[tuple[int, int]]] = set()
     layers = get_layers_from_vllm_config(vllm_config, Attention)
     for layer in layers.values():
-        assert isinstance(layer.impl, FlashAttentionImpl)
+        assert hasattr(layer.impl, "sliding_window")
```
maybe `assert isinstance(layer.impl, (FlashAttentionImpl, TreeAttentionImpl))`?
```python
    return depth_counts


def _prepare_tree_attn_bias(
```
Let's just make this a public API in xFormers instead of duplicating code? https://github.com/facebookresearch/xformers/blob/80250b32516b019b72bb44be04ca9a8741b42faa/xformers/ops/tree_attention.py#L259
```python
    spec_v=spec_v,
    cache_k=cache_k,
    cache_v=cache_v,
    prefix_op=triton_splitk.FwOp,
```
Can we leave `prefix_op` as None and rely on the heuristic https://github.com/facebookresearch/xformers/blob/80250b32516b019b72bb44be04ca9a8741b42faa/xformers/ops/tree_attention.py#L469C5-L469C21 to choose the prefix op?
cc @bottler
Purpose
Add support for the tree attention v1 backend. Tree attention is used in EAGLE speculative decoding by the target model to validate a set of draft tokens. Draft tokens only attend to their ancestor tokens, so an attention bias must be applied to mask out attention between tokens that are not on the same ancestor path.
Currently, `TreeAttentionImpl` uses xFormers' `tree_attention` operation. This operation requires both a prefix and a suffix attention bias. The former is used for attention between the draft tokens and the prompt tokens; the latter is used for attention of the draft tokens over their ancestors. The outputs of the two attentions are then merged.
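As an illustration of that masking rule (a sketch over a hypothetical tree, not the PR's `_prepare_tree_attn_bias` code): with the tuple representation of a draft tree, a query token may attend to a key token exactly when the key's path is a prefix of the query's path, i.e. the key is an ancestor of the query or the query itself. The prefix bias toward prompt tokens is unrestricted, since every draft token may see the whole prompt.

```python
import torch

# Hypothetical draft tree: the root draft token, two children, and two
# grandchildren under the first child.
tree_choices = [(0,), (1,), (0, 0), (0, 1)]
paths = [()] + sorted(tree_choices, key=lambda p: (len(p), p))  # () is the root

def is_ancestor_or_self(key_path, query_path):
    # key is on query's ancestor path iff key_path is a prefix of query_path.
    return query_path[: len(key_path)] == key_path

n = len(paths)
# Suffix (speculative) bias over draft tokens: 0 where attention is allowed,
# -inf where the query must not see the key.
suffix_bias = torch.full((n, n), float("-inf"))
for q, query_path in enumerate(paths):
    for k, key_path in enumerate(paths):
        if is_ancestor_or_self(key_path, query_path):
            suffix_bias[q, k] = 0.0

print(suffix_bias)
```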
Test Plan
Added test `test_tree_attn_correctness`, which verifies that tree attention output for draft chains exactly matches flash attention for the same number of query tokens, for several configurations. This validates the correctness of this backend.
Benchmark
In addition, I used the following command to run the LLM service and benchmark TreeAttentionBackend vs FlashAttentionBackend:
Server
Client
Results
This benchmarking helped me verify that this PR did NOT regress performance on v1 spec decoding.
Improvements still need to be made for tree attention performance. I will investigate further how to close the gap.
Manual Testing
Used the code below to send a completion request to the vLLM service running with the TREE_ATTN backend:
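(The exact script is not reproduced here; a request along these lines can be sent through vLLM's OpenAI-compatible completions endpoint, with the host, port, and model name below being placeholders.)

```python
import requests

# Placeholder server address and model name; adjust to the running instance.
response = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "meta-llama/Llama-3.1-8B-Instruct",
        "prompt": "Explain the theory of relativity in simple terms.",
        "max_tokens": 128,
        "temperature": 0.0,
    },
    timeout=60,
)
print(response.json()["choices"][0]["text"])
```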
Flash Attention Output
Tree Attention Output
Tree Drafts
I tested generating a tree with the following structure:
Represented by the following list of tuples:
For the input prompt, "Explain the theory of relativity in simple terms.", the backend proposed the following speculative tokens:
Each path in the tree sounds coherent.
NOTE: There is currently no way to sample tokens from a tree, so doing this would produce gibberish outputs.
TODOs
The following actions still need to be taken to fully enable this backend:
- As of this diff, only chain drafts are supported by `TreeAttentionBackend`. This is because `EagleProposer` still only generates draft chains.